from __future__ import print_function
import torch
import torch.distributions
import torch.utils.data
import numpy as np
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


def ten2ar(tensor):
    """ Converts a torch tensor to a numpy array

    The tensor is detached from the autograd graph and moved to the CPU before the conversion.
    """
    detached = tensor.detach()
    on_host = detached.cpu()
    return on_host.numpy()


def add_n_dims(generalized_tensor, n, dim=-1):
    """ Adds n new dimensions of size 1 to a tensor or array

    :param generalized_tensor: a torch tensor or a numpy array
    :param n: the number of dimensions to add. If n <= 0, the input is returned unchanged
    :param dim: the index at which every new dimension is inserted. The default (-1) appends them at the end
    :return: the input with n additional dimensions of size 1, of the same type as the input
    """
    # torch.unsqueeze only accepts tensors, so numpy arrays are routed to the numpy equivalent
    if isinstance(generalized_tensor, np.ndarray):
        unsqueeze = np.expand_dims
    else:
        unsqueeze = torch.unsqueeze

    for _ in range(n):
        generalized_tensor = unsqueeze(generalized_tensor, dim)
    return generalized_tensor


def broadcast_final(t1, t2):
    """ Appends dimensions of size 1 to t1 until it has as many dimensions as t2

    :param t1: the tensor to extend
    :param t2: the tensor whose number of dimensions should be matched
    :return: t1 with trailing singleton dimensions added
    """
    rank_gap = len(t2.shape) - len(t1.shape)
    return add_n_dims(t1, rank_gap)


def unpackbits(ten, dim=-1):
    """ Expands every uint8 element of a tensor into its bits using np.unpackbits

    :param ten: a tensor of dtype torch.uint8
    :param dim: the position at which the new dimension holding the bits is inserted
    :return: a tensor with one extra dimension at dim that holds the bits, placed on the same device as ten
    """
    assert ten.dtype == torch.uint8
    # np.unpackbits expands along an existing axis, so a singleton axis is created for the bits first
    with_bit_axis = np.expand_dims(ten2ar(ten), dim)
    bits = torch.from_numpy(np.unpackbits(with_bit_axis, dim))
    return bits.to(ten.device)


def combine_dim(x, dim_begin, dim_end=None):
    """ Merges a range of consecutive dimensions of x into a single dimension

    :param x: the tensor to reshape
    :param dim_begin: the first dimension that is merged
    :param dim_end: the first dimension that is kept again. If None, everything from dim_begin on is merged
    :return: x reshaped so that dimensions dim_begin..dim_end-1 are collapsed into one
    """
    shape = list(x.shape)
    kept_tail = [] if dim_end is None else shape[dim_end:]
    return x.reshape(shape[:dim_begin] + [-1] + kept_tail)


def packbits(ten, dim=-1):
    """ Applies np.packbits along dim and removes that dimension afterwards

    :param ten: a tensor of dtype torch.uint8 holding the bits
    :param dim: the dimension that holds the bits
    :return: a tensor of the packed values without the dimension dim, placed on the same device as ten
    """
    assert ten.dtype == torch.uint8
    packed = np.packbits(ten2ar(ten), dim)
    # squeeze(dim) requires the packed dimension to have size 1, i.e. no more than 8 bits along dim
    packed = packed.squeeze(dim)
    return torch.from_numpy(packed).to(ten.device)


def find_extra_dim(smaller, larger):
    """ This function finds the position of the extra dimension in two tensors that only differ by one dimension

    If several positions are possible, the position of the first mismatch between the two shapes is returned. If
    all sizes are identical (e.g. shapes (3, 3) and (3, 3, 3)), 0 is returned.

    :param smaller: a tensor
    :param larger: a tensor that has the same shape as smaller, except for one extra dimension that can be anywhere
    :return: the integer index of the extra dimension
    :raises AssertionError: if the shapes do not differ by exactly one inserted dimension
    """
    shape_smaller = np.array(smaller.shape)
    shape_larger = np.array(larger.shape)
    assert len(shape_smaller) + 1 == len(shape_larger)

    mismatch = shape_smaller != shape_larger[:-1]
    if mismatch.any():
        # The extra dimension sits at the first position where the shapes stop agreeing
        diff_idx = int(mismatch.argmax())
    elif (shape_smaller == shape_larger[1:]).all():
        # Every position is a valid answer here (this includes a 0-d smaller tensor); use the first one
        diff_idx = 0
    else:
        # smaller is a prefix of larger, so the extra dimension is the last one
        diff_idx = len(shape_smaller)

    # Everything after the extra dimension has to line up again
    assert (shape_smaller[diff_idx:] == shape_larger[diff_idx + 1:]).all()

    return diff_idx


class AttrDict(dict):
    """ A dict whose items can also be read and written as attributes """

    def __getattr__(self, attr):
        # A missing key has to surface as AttributeError instead of KeyError.
        # hasattr(), deepcopy and OrderedDict depend on that.
        try:
            value = dict.__getitem__(self, attr)
        except KeyError:
            raise AttributeError("Attribute %r not found" % attr)
        else:
            return value

    # Assigning an attribute stores the value as a dict item
    __setattr__ = dict.__setitem__


class Distribution():
    """ Base class of the distributions in this file. Subclasses override the methods they support """

    def nll(self, x):
        """ Returns the negative log likelihood of x

        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError

    def sample(self, x=None):
        """ Draws a sample from the distribution

        :param x: not used by the base class. It is optional so that sample() can be called without arguments on
        every distribution, the way Gaussian.sample is defined
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError


class LocScaleDistribution(Distribution):
    """ Base class for distributions that are described by a location mu and a scale sigma

    The scale can be given either as sigma or as log_sigma; the other form is computed on first access and cached.
    """

    def __init__(self, mu, log_sigma=None, sigma=None, concat_dim=-1):
        """

        :param mu: the mean. this parameter should have the shape of the desired distribution
        :param log_sigma: the logarithm of the scale. If both log_sigma and sigma are None, mu is divided into two
        chunks along concat_dim: the first one is used as mu and the second one as log_sigma
        :param sigma: the scale. Can be given instead of log_sigma
        :param concat_dim: the dimension along which mu is split when no scale is given
        :raises TypeError: if no scale is given and mu is not a torch.Tensor
        """
        if log_sigma is None and sigma is None:
            if not isinstance(mu, torch.Tensor):
                # Fail loudly instead of dropping into a debugger, which would hang non-interactive runs
                raise TypeError("mu must be a torch.Tensor when neither log_sigma nor sigma is given, got %s"
                                % type(mu).__name__)
            mu, log_sigma = torch.chunk(mu, 2, concat_dim)

        self.mu = mu
        self._log_sigma = log_sigma
        self._sigma = sigma
        self.concat_dim = concat_dim

    @property
    def sigma(self):
        """ The scale. Computed as exp(log_sigma) on first access if only log_sigma was given """
        if self._sigma is None:
            self._sigma = self._log_sigma.exp()
        return self._sigma

    @property
    def log_sigma(self):
        """ The logarithm of the scale. Computed as log(sigma) on first access if only sigma was given """
        if self._log_sigma is None:
            self._log_sigma = self._sigma.log()
        return self._log_sigma


class Gaussian(LocScaleDistribution):
    """ A gaussian distribution with mean mu and standard deviation sigma """

    @property
    def mean(self):
        return self.mu

    def sample(self):
        """ Draws a sample of the same shape as mu by shifting and scaling standard normal noise """
        noise = torch.randn_like(self.mu)
        return self.mu + self.sigma * noise

    def nll(self, x):
        """ Returns the elementwise negative log likelihood (negative log density) of x """
        standardized = (x - self.mu) / self.sigma
        quadratic = 0.5 * torch.pow(standardized, 2)
        return quadratic + self.log_sigma + 0.5 * np.log(2 * np.pi)

    def optimal_variance_nll(self, x):
        """ Computes the NLL of a gaussian with the optimal (constant) variance for these data

        The standard deviation is a single scalar: the root of the mean squared difference between x and mu.
        """
        rms_error = ((x - self.mu) ** 2).mean().sqrt()
        return Gaussian(mu=self.mu, sigma=rms_error).nll(x)
    

class DiscreteLogistic(Distribution):
    """ A logistic distribution discretized into bins of width 1/256

    The original docstring only read "IAFVAE"; presumably the formulation follows the IAF-VAE paper - TODO confirm
    """

    def __init__(self, mu, log_sigma, range=None):
        # NOTE the range argument is accepted but not used
        self.mu = mu
        self.log_sigma = log_sigma

    @property
    def mean(self):
        return self.mu

    def cdf(self, x):
        """ The logistic CDF of an already standardized value """
        return torch.sigmoid(x)

    def prob(self, x):
        """ Returns the probability mass of the bin that x falls into

        :param x: the values to evaluate. The edge cases below treat 0 and 1 as the ends of the data range
        :return: the elementwise probability mass
        """
        binsize = 1 / 256.0

        # These are determined on the raw values, before x is standardized
        at_lower_edge = x == 0
        at_upper_edge = x == 1

        scale = torch.exp(self.log_sigma)
        # Snap x to the left edge of its bin and standardize it
        left_edge = (torch.floor(x / binsize) * binsize - self.mu) / scale

        cdf_right = self.cdf(left_edge + binsize / scale)
        cdf_left = self.cdf(left_edge)
        p = cdf_right - cdf_left

        # Edge cases: x == 0 receives all the mass below the right edge of its bin,
        # x == 1 receives all the mass above the left edge of its bin
        p[at_lower_edge] = cdf_right[at_lower_edge]
        p[at_upper_edge] = (1 - cdf_left)[at_upper_edge]

        return p

    def nll(self, x):
        """ Returns the elementwise negative log of the probability mass of x """
        # Add epsilon for stability
        return -torch.log(self.prob(x) + 1e-7)


class DiscreteLogisticMixture(DiscreteLogistic):
    """ A mixture of n discretized logistic distributions. The components are averaged with equal weights """

    def __init__(self, mu, log_sigma, n=5):
        """

        :param mu: the means of all components. Dimension 1 is split into (n, -1), so that dimension 1 of the stored
        mu indexes the mixture component
        :param log_sigma: the log scales. It is reshaped like mu if it is a tensor of the same shape as mu, and
        stored unchanged otherwise
        :param n: number of elements in the mixture
        """
        full_shape = mu.shape
        split_shape = full_shape[0:1] + (n, -1) + full_shape[2:]
        shares_mu_shape = isinstance(log_sigma, torch.Tensor) and log_sigma.shape == full_shape

        self.mu = mu.reshape(split_shape)
        self.log_sigma = log_sigma.reshape(split_shape) if shares_mu_shape else log_sigma

    @property
    def mean(self):
        return self.mu.mean(1)

    def nll(self, x):
        """ Returns the negative log of the probability mass of x, averaged over the mixture components """
        # mu, log_sigma: batch x mixture x dims
        n_components = self.mu.shape[1]
        # Repeat x once per component along a new dimension 1
        replicated = x[:, None][:, [0] * n_components]
        mixture_prob = self.prob(replicated).mean(1)

        return -(mixture_prob + 1e-7).log()


class Beta(Distribution):
    """ A Beta distribution defined on [0,1]

    In nll, a is the exponent parameter of x and b the one of (1 - x). base_dist is constructed as
    torch.distributions.Beta(b, a), i.e. with the two parameters in the opposite roles.
    """

    def __init__(self, b, a):
        self.a = a
        self.b = b
        self.base_dist = torch.distributions.Beta(b, a)

    @property
    def log_norm(self):
        """ The log normalizer of base_dist, evaluated at its two concentration parameters """
        base = self.base_dist
        return base._log_normalizer(base.concentration0, base.concentration1)

    @property
    def mean(self):
        # base_dist is Beta(b, a), whose mean is b / (a + b). The density used in nll has the parameters in the
        # opposite roles, so its mean is a / (a + b) = 1 - base_dist.mean.
        # TODO presumably this is the reason for the flip that the original comment could not explain - confirm
        return 1 - self.base_dist.mean

    def nll(self, x, eps=1e-7):
        """ Returns the elementwise negative log density of x

        :param x: the values to evaluate
        :param eps: added inside both logarithms for stability
        """
        log_x = (x + eps).log()
        log_one_minus_x = (1 - x + eps).log()

        a_term = (self.a - 1) * log_x
        b_term = (self.b - 1) * log_one_minus_x
        return -(a_term + b_term - self.log_norm)


class SmartCrossEntropyLoss(CrossEntropyLoss):
    """ This is a helper class that automatically finds which dimension is the classification dimension
    (as opposed to it always being dim=1) """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """ Computes the cross entropy with the classification dimension inferred from the shapes

        :param input: the scores. It has the shape of target plus one extra dimension that holds the classes
        :param target: the class indices
        :return: the loss. With reduction='none' it has the shape of target
        """
        # Find the dimension that has the distribution
        diff_idx = find_extra_dim(target, input)
        shape_target = target.shape

        # Flatten everything in front of the class dimension into one batch dimension, which moves the classes to
        # dim=1 as the parent class expects. reshape is used for both tensors since view fails on non-contiguous
        # inputs (e.g. after a permute).
        target = target.reshape((-1,) + target.shape[diff_idx:])
        input = input.reshape((-1,) + input.shape[diff_idx:])

        loss = super().forward(input, target)

        if self.reduction == 'none':
            # Restore the leading dimensions that were flattened above
            loss = loss.reshape(tuple(shape_target[:diff_idx]) + loss.shape[1:])

        return loss


class Categorical(Distribution):
    """ A categorical distribution that is given either by probabilities p or by log probabilities log_p

    The classes are expected along dimension 1 when p is computed from log_p.
    """

    def __init__(self, p=None, log_p=None):
        """

        :param p: the probabilities. Only one of p and log_p may be given
        :param log_p: the log probabilities
        """
        # TODO log_p is actually unnormalized in most cases
        assert p is None or log_p is None

        self._log_p = log_p
        self._p = p

    @property
    def p(self):
        """ The probabilities. None if neither p nor log_p was given """
        if self._p is not None:
            return self._p
        elif self._log_p is not None:
            # softmax equals exp(log_p) / sum(exp(log_p)) but does not overflow to NaN for large values of log_p
            return F.softmax(self._log_p, dim=1)
        return None

    @property
    def log_p(self):
        """ The (possibly unnormalized) log probabilities. None if neither p nor log_p was given """
        if self._p is not None:
            return self._p.log()
        elif self._log_p is not None:
            return self._log_p
        return None

    def nll(self, x):
        """ Returns the elementwise negative log likelihood of the class indices x

        :param x: the class indices. They are rounded and converted to integers
        :return: the cross entropy without reduction. None if neither p nor log_p was given
        """
        # Going through the property also covers a distribution that was built from p. The cross entropy
        # normalizes its input, so log(p) is a valid input.
        log_p = self.log_p
        if log_p is not None:
            return SmartCrossEntropyLoss(reduction='none')(log_p, x.round().long())
        return None


class Bernoulli(Categorical):
    """ A Bernoulli distribution. Here log_p holds logits: p is computed as sigmoid(log_p)

    NOTE(review): the inherited log_p property returns p.log() for a distribution that was built from p, which is
    not a logit - confirm before relying on log_p in that case.
    """

    def nll(self, x):
        """ Returns the elementwise binary cross entropy of x

        :param x: the targets, of the same shape as the parameters
        :return: the binary cross entropy without reduction. None if neither p nor log_p was given
        """
        if self._log_p is not None:
            return F.binary_cross_entropy_with_logits(self._log_p, x, reduction='none')
        if self._p is not None:
            # Built from probabilities: there are no logits, so use the probability form of the loss
            return F.binary_cross_entropy(self._p, x, reduction='none')
        return None

    @property
    def p(self):
        """ The probabilities. None if neither p nor log_p was given """
        if self._p is not None:
            return self._p
        elif self._log_p is not None:
            return self._log_p.sigmoid()
        return None


class ImageBitwiseCategorical(Bernoulli):
    """ This is useful to represent a bitwise distribution - a distribution over each bit in a tensor

    It should be initialized with a tensor batch x bits x image_dims, where bits=8, the number of bits needed to
    describe each channel value. Images are mapped from -1..1 to 0..255 before they are split into bits.
    """

    def nll(self, x):
        """ Returns the per-bit negative log likelihood of the image x

        The bit dimension is merged with the dimension that follows it in the result.
        """
        bit_dim = find_extra_dim(x, self.log_p)
        as_bytes = ((x + 1) * 127.5).round().byte()
        target_bits = unpackbits(as_bytes, bit_dim).float()

        per_bit_nll = super().nll(target_bits)
        return combine_dim(per_bit_nll, bit_dim, bit_dim + 2)

    @property
    def mle(self):
        """ The image obtained by setting every bit whose log_p is positive, mapped back to -1..1 """
        likely_bits = (self.log_p > 0).byte()
        packed = packbits(likely_bits, 1).float()
        return packed / 127.5 - 1

    @property
    def mean(self):
        """ The sum over the bits of p times the value of the bit, mapped back to -1..1 """
        p = self.p
        # The value of each of the 8 bits, most significant bit first
        bit_values = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1]).to(p.device).float()
        bit_values = broadcast_final(bit_values[None], p)
        return (p * bit_values).sum(1) / 127.5 - 1


class ImageCategorical(Categorical):
    """ This converts the input image from -1..1 to 0..255. This is useful to represent an image as a categorical
    distribution over all pixel values. It is factorized over colors and spatial locations.

    It should be initialized with a tensor batch x pixel_values x image_dims, where pixel_values=256, the number of
    different values a pixel is allowed to take."""

    def nll(self, x):
        """ Returns the negative log likelihood of the image x after mapping it from -1..1 to 0..255 """
        pixel_values = (x + 1) * 127.5
        return super().nll(pixel_values)

    @property
    def mle(self):
        """ The pixel value with the highest log_p along dimension 1, mapped back to -1..1 """
        best_value = self.log_p.argmax(1).float()
        return best_value / 127.5 - 1

    @property
    def mean(self):
        """ The sum over the pixel values of p times the value, mapped back to -1..1 """
        p = self.p
        # One entry for each of the 256 pixel values
        all_values = torch.arange(256).to(p.device).float()
        all_values = broadcast_final(all_values[None], p)
        return (p * all_values).sum(1) / 127.5 - 1

